呢哇 #12
base: main
Conversation
- Add missing newline at end of file for graph.cc
- Removed TODO comments for memory allocation and deallocation.
- Updated inferShape method to correctly compute the concatenated shape.
- Added a missing newline at the end of the file and updated comments.
- Update transpose function to reverse dimensions
- Removed TODO comments and added clarification about shape.
Pull request overview
This PR implements homework assignments for a deep learning framework, completing several operator shape inference functions, broadcast logic, graph optimization passes, and memory allocation. The PR title "呢哇" (Chinese characters) doesn't describe the changes.
Changes:
- Implements shape inference for Clip, Cast, Transpose, MatMul, and Concat operators
- Implements bidirectional broadcasting shape inference (infer_broadcast)
- Adds graph optimization passes to eliminate redundant transposes and fuse transposes into matmul operations
- Implements memory allocation with first-fit strategy and block coalescing
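The bidirectional broadcasting listed above follows the ONNX/NumPy convention: shapes are aligned from the trailing dimension, and a dimension of 1 stretches to match the other input. A minimal self-contained sketch (the `Shape` alias and the `inferBroadcast` free-function form are assumptions for illustration, not the project's actual signatures):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

using Shape = std::vector<int>;

// Sketch of ONNX/NumPy-style bidirectional broadcasting:
// align from the trailing dimension; missing dims count as 1;
// a dim of 1 stretches to match the other input.
Shape inferBroadcast(const Shape &a, const Shape &b) {
    size_t rank = std::max(a.size(), b.size());
    Shape out(rank);
    for (size_t i = 0; i < rank; ++i) {
        // Dimensions missing on the left are treated as 1.
        int da = i < rank - a.size() ? 1 : a[i - (rank - a.size())];
        int db = i < rank - b.size() ? 1 : b[i - (rank - b.size())];
        assert(da == db || da == 1 || db == 1); // incompatible otherwise
        out[i] = std::max(da, db);
    }
    return out;
}
```

For example, broadcasting `{2, 1, 4}` against `{3, 1}` yields `{2, 3, 4}`.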
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| src/utils/operator_utils.cc | Implements broadcast shape inference and applies formatting changes to namespace braces |
| src/operators/unary.cc | Implements shape inference for Clip and Cast operators |
| src/operators/transpose.cc | Fixes default permutation to reverse dimensions per ONNX spec and implements shape inference with validation |
| src/operators/matmul.cc | Implements shape inference with batch broadcasting and transpose support |
| src/operators/concat.cc | Implements shape inference with dimension validation |
| src/core/graph.cc | Implements graph optimization passes and memory allocation with lifetime analysis |
| src/core/allocator.cc | Implements memory allocator with first-fit allocation and block coalescing |
| include/core/allocator.h | Adds data structures and helper methods for memory management |
```cpp
{
    if (r == dim)
        continue;
    IT_ASSERT(cur[r] == dims[r], "Concat dims mismatch on non-concat axis");
```
Copilot AI · Feb 3, 2026
Similar to the other assertion, this message could be more informative by including which permutation element caused the issue. Consider: "Concat dims mismatch on non-concat axis " + std::to_string(r) + ": expected " + std::to_string(dims[r]) + ", got " + std::to_string(cur[r]).
```cpp
{
    int p = transposePermute[i];
    IT_ASSERT(p >= 0 && p < rank);
    IT_ASSERT(seen[p] == 0);
```
Copilot AI · Feb 3, 2026
The error message for duplicate permutation indices lacks helpful details. Consider: "Duplicate permutation index " + std::to_string(p) + " at position " + std::to_string(i) + " in transpose permutation".
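A hedged sketch of the check this comment describes, as a hypothetical standalone `checkPermute` helper that returns an empty string for a valid permutation and a descriptive message otherwise (the name and return convention are illustrative, not the project's API — the real code uses `IT_ASSERT`):

```cpp
#include <string>
#include <vector>

// Validate that `perm` is a permutation of [0, rank); on failure,
// report which element and position caused the problem.
std::string checkPermute(const std::vector<int> &perm) {
    std::vector<int> seen(perm.size(), 0);
    for (size_t i = 0; i < perm.size(); ++i) {
        int p = perm[i];
        if (p < 0 || p >= static_cast<int>(perm.size()))
            return "Permutation index " + std::to_string(p) + " out of range";
        if (seen[p]++) // already used: duplicate entry
            return "Duplicate permutation index " + std::to_string(p) +
                   " at position " + std::to_string(i);
    }
    return ""; // valid permutation
}
```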
```cpp
std::map<size_t, size_t>::iterator Allocator::findFreeBlock(size_t size)
{
    // first-fit: return the first free block whose size is sufficient
    for (auto it = freeBlocks.begin(); it != freeBlocks.end(); ++it)
    {
        if (it->second >= size)
            return it;
    }
    return freeBlocks.end();
}
```
Copilot AI · Feb 3, 2026
The findFreeBlock function uses a linear search (first-fit) through all free blocks. For workloads with many allocations and deallocations, this could become a performance bottleneck as it has O(n) complexity where n is the number of free blocks. Consider using a more efficient data structure, such as maintaining free blocks sorted by size (in addition to address) or using a best-fit strategy with size-indexed structures for better performance. However, if the number of tensors is small (typical in most ML graphs), this may not be a practical concern.
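One possible shape of the size-indexed structure this comment mentions — a sketch, assuming free blocks are mirrored into a `std::multimap` keyed by size so a best-fit lookup is O(log n). A real allocator would also keep the existing by-address map for coalescing and keep the two structures in sync on every allocation and free:

```cpp
#include <cstddef>
#include <map>

// Hypothetical size-indexed view of the free list.
// Key: block size; value: block start address.
using FreeBySize = std::multimap<std::size_t, std::size_t>;

// Best-fit: the smallest free block that still fits the request,
// found in O(log n) via lower_bound, or end() if none fits.
FreeBySize::iterator findBestFit(FreeBySize &freeBySize, std::size_t size) {
    return freeBySize.lower_bound(size); // first key >= requested size
}
```

As the comment notes, for the small tensor counts typical of ML graphs the linear first-fit scan is usually fine; this refinement only pays off with many live blocks.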
```cpp
Shape dims = inputs[0]->getDims();
// =================================== HOMEWORK ===================================
// TODO: modify dims to return the correct shape after concat
// REF: https://onnx.ai/onnx/operators/onnx__Concat.html#concat-13
// =================================== HOMEWORK ===================================
IT_ASSERT(!inputs.empty());
```
Copilot AI · Feb 3, 2026
The assertion that inputs is not empty (line 22) comes after dims is already initialized from inputs[0] (line 16). If inputs were empty, line 16 would access an invalid element before the assertion could catch it. Move the IT_ASSERT(!inputs.empty()) check to before line 16 to ensure the precondition is validated first.
```cpp
auto tr = as<TransposeObj>(src);
if (tr->getOutput() != in)
    continue;
if (!isSwapLast2Permute(tr->getPermute()))
    continue;

auto trIn = tr->getInputs(0);
mm->replaceInput(in, trIn);
if (inputIdx == 0)
    mm->setTransA(!mm->getTransA());
else
    mm->setTransB(!mm->getTransB());
toRemove.insert(src.get());
changed = true;
```
Copilot AI · Feb 3, 2026
Potential issue: When a transpose operator is marked for removal and added to the toRemove set, the code doesn't verify that the transpose output is only consumed by this matmul operator. If the transpose output is used by multiple consumers, removing it would break other parts of the graph. Add a check to ensure tr->getOutput()->getTargets().size() == 1 before adding to toRemove, similar to the check in rule 1 at line 157.
```cpp
m = aM;
n = bN;
k = aK;
```
Copilot AI · Feb 3, 2026
The matmul member variables m, n, and k are set as a side effect of inferShape. Because this mutable state is only updated when inferShape runs, it can become stale or inconsistent if inferShape is called multiple times or in different contexts. Consider computing these values on demand via getter methods instead of caching them, or document that they are only valid after inferShape has been called.
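A sketch of the on-demand alternative this comment suggests, with a hypothetical `MatmulDims` struct (the name, layout, and field names are illustrative, not the project's types). The matrix dimensions are derived from the last two axes of each input shape whenever they are asked for, so they can never go stale:

```cpp
#include <vector>

using Shape = std::vector<int>;

// Derive m, n, k on demand from the input shapes instead of caching
// them as side effects of shape inference.
struct MatmulDims {
    Shape a, b;                        // last two dims are the matrix dims
    bool transA = false, transB = false;

    int m() const { return transA ? a.back() : a[a.size() - 2]; }
    int k() const { return transA ? a[a.size() - 2] : a.back(); }
    int n() const { return transB ? b[b.size() - 2] : b.back(); }
};
```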
```cpp
else if (dimB == 1)
    out[i] = dimA;
else
    IT_ASSERT(false, "Broadcast shape mismatch");
```
Copilot AI · Feb 3, 2026
The assertion message "Broadcast shape mismatch" on line 31 could be more informative. Consider including the actual dimension values and indices to help with debugging. For example: "Broadcast shape mismatch at dimension " + std::to_string(i) + ": dimA=" + std::to_string(dimA) + ", dimB=" + std::to_string(dimB).
Suggested change (replacing `IT_ASSERT(false, "Broadcast shape mismatch");`):

```cpp
{
    std::string msg = "Broadcast shape mismatch at dimension " +
                      std::to_string(i) + ": dimA=" +
                      std::to_string(dimA) + ", dimB=" +
                      std::to_string(dimB);
    IT_ASSERT(false, msg);
}
```
No description provided.